Dealing With PyTorch Custom Datasets.

In this article, we take a look at how to deal with custom PyTorch datasets.


Custom datasets!! WHY??

Because you can shape them any way you desire!!!

It is natural that we each develop our own way of creating custom datasets while working on different projects.

There are some official custom dataset examples on the PyTorch site, but they seemed a bit obscure to a beginner (like me, back then). The topics we will discuss are as follows.

  1. Custom Dataset Fundamentals
  2. Using Torchvision Transforms
  3. Dealing with pandas (read_csv)
  4. Embedding Classes into File Names
  5. Using DataLoader

1. Custom Dataset Fundamentals.

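A dataset must implement the following methods to be used by a DataLoader later on: __init__() for the initial setup (reading file lists, storing transforms), __getitem__() to return one sample for a given index, and __len__() to return the number of samples.

The first step is to create a dataset class: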

from torch.utils.data.dataset import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        
    def __getitem__(self, index):
        # stuff
        return (img, label)

    def __len__(self):
        return count # of how many examples(images?) you have

Here, MyCustomDataset returns two things: an image and its label. But this doesn’t mean that __getitem__() is restricted to returning only those; it can return additional items as well.
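To make the skeleton concrete, here is a minimal runnable sketch. The class name RandomImageDataset, its parameters, and the random data are all made up for illustration; a real dataset would read images from disk instead. It also returns a third item (the index) to show that __getitem__() can return more than an image and a label.

```python
import torch
from torch.utils.data import Dataset


class RandomImageDataset(Dataset):
    """Hypothetical toy dataset: random 1x28x28 'images' with integer labels."""

    def __init__(self, num_samples=8, num_classes=10, seed=0):
        generator = torch.Generator().manual_seed(seed)
        # Fake grayscale images with values in [0, 1), shape (N, 1, 28, 28)
        self.images = torch.rand(num_samples, 1, 28, 28, generator=generator)
        # Fake labels in [0, num_classes)
        self.labels = torch.randint(0, num_classes, (num_samples,), generator=generator)

    def __getitem__(self, index):
        # Return the image, its label, and (as an extra item) the index itself
        return self.images[index], int(self.labels[index]), index

    def __len__(self):
        return len(self.images)


toy_dataset = RandomImageDataset()
img, label, idx = toy_dataset[3]
print(len(toy_dataset), img.shape, label, idx)
```

Indexing the dataset with toy_dataset[3] is exactly what a DataLoader does internally for every sample it collects.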

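NOTE: __getitem__() must return types that the DataLoader can batch, such as tensors, numbers, dicts, or lists. If it returns something else, for example a PIL image, the DataLoader raises an error like this:

TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'PIL.PngImagePlugin.PngImageFile'>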
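The error above can be reproduced with a small sketch. PilOnlyDataset is a hypothetical class that deliberately returns raw PIL images; the exact wording of the message may differ between PyTorch versions, so the code only relies on a TypeError being raised.

```python
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class PilOnlyDataset(Dataset):
    """Hypothetical dataset that (wrongly) returns raw PIL images."""

    def __getitem__(self, index):
        # Blank grayscale image and a dummy label
        return Image.new("L", (28, 28)), 0

    def __len__(self):
        return 4


try:
    # Asking for one batch forces the DataLoader to collate PIL images
    next(iter(DataLoader(PilOnlyDataset(), batch_size=2)))
    collate_error = None
except TypeError as err:
    collate_error = err
    print(type(err).__name__, err)
```

Converting the image to a tensor inside __getitem__() (for instance with transforms.ToTensor()) avoids the problem.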

2. Using Torchvision Transforms:

In most examples, you will see transforms=None in __init__(). This argument is used to apply Torchvision transforms to our data/images. The full list of transforms is in the Torchvision documentation.

The most common usage of transforms looks like this:

from torch.utils.data.dataset import Dataset
from torchvision import transforms

class MyCustomDataset(Dataset):
    def __init__(self, ..., transforms=None):
        # stuff
        ...
        self.transforms = transforms
        
    def __getitem__(self, index):
        # stuff
        ...
        data = ...   # some data read from a file or image
        label = ...  # the label that belongs to this sample
        # If transforms were given, apply them in the order
        # in which they were composed
        if self.transforms is not None:
            data = self.transforms(data)
        return (data, label)

    def __len__(self):
        return count # of how many data(images?) you have
        
if __name__ == '__main__':
    # Define transforms (1)
    transformations = transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()])
    # Call the dataset
    custom_dataset = MyCustomDataset(..., transformations)

You can also define the transforms inside the Dataset class, like this:

from torch.utils.data.dataset import Dataset
from torchvision import transforms

class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        ...
        # (2) One way to do it is define transforms individually
        self.center_crop = transforms.CenterCrop(100)
        self.to_tensor = transforms.ToTensor()
        
        # (3) Or you can still compose them like 
        self.transformations = \
            transforms.Compose([transforms.CenterCrop(100),
                                transforms.ToTensor()])
        
    def __getitem__(self, index):
        # stuff
        ...
        data = ...   # some data read from a file or image
        label = ...  # the label that belongs to this sample

        # The transform objects were constructed in __init__();
        # calling one here invokes its __call__() and applies the transform
        data = self.center_crop(data)  # (2)
        data = self.to_tensor(data)  # (2)

        # Or you can call the composed version
        data = self.transformations(data)  # (3)

        # Note that you only need one of the implementations, (2) or (3)
        return (data, label)

    def __len__(self):
        return count # of how many data(images?) you have
        
if __name__ == '__main__':
    # Call the dataset
    custom_dataset = MyCustomDataset(...)

3. Dealing with Pandas(read_csv):

Now suppose our dataset is described by a CSV file that contains a file name, a label, and an extra operation indicator. When the indicator is set, we perform some extra operation on the image.

+-----------+-------+-----------------+
| File Name | Label | Extra Operation |
+-----------+-------+-----------------+
| tr_0.png  |     5 | TRUE            |
| tr_1.png  |     0 | FALSE           |
| tr_2.png  |     4 | FALSE           |
+-----------+-------+-----------------+

Let's build a custom dataset that reads the image locations from this CSV:

import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class CustomDatasetFromImages(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): path to the csv file that lists image paths,
                labels, and operation indicators
        """
        # Transform used to convert PIL images to tensors
        self.to_tensor = transforms.ToTensor()
        # Read the csv file (no header row, hence header=None)
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # Third column is for an operation indicator
        self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
        # Calculate len
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        # Get image name from the pandas df
        single_image_name = self.image_arr[index]
        # Open image
        img_as_img = Image.open(single_image_name)

        # Check if there is an operation
        some_operation = self.operation_arr[index]
        # If there is an operation
        if some_operation:
            # Do some operation on image
            # ...
            # ...
            pass
        # Transform image to tensor
        img_as_tensor = self.to_tensor(img_as_img)

        # Get label of the image based on the cropped pandas column
        single_image_label = self.label_arr[index]

        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len

if __name__ == "__main__":
    # Call dataset
    custom_mnist_from_images =  \
        CustomDatasetFromImages('../data/mnist_labels.csv')

Another example is reading images from a CSV in which the value of each pixel is listed in its own column (e.g., MNIST). This needs a small change of logic in __getitem__(). In the end, we again return the images as tensors together with their labels. The data looks like this:

+-------+---------+---------+-----+
| Label | pixel_1 | pixel_2 | ... |
+-------+---------+---------+-----+
|     1 |      50 |      99 | ... |
|     0 |      21 |     223 | ... |
|     9 |      44 |     112 | ... |
+-------+---------+---------+-----+

Now, the code looks like:

import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path, height, width, transforms=None):
        """
        Args:
            csv_path (string): path to csv file
            height (int): image height
            width (int): image width
            transforms: pytorch transforms for transforms and tensor conversion
        """
        self.data = pd.read_csv(csv_path)
        self.labels = np.asarray(self.data.iloc[:, 0])
        self.height = height
        self.width = width
        self.transforms = transforms

    def __getitem__(self, index):
        single_image_label = self.labels[index]
        # Read the pixel values of one row (784 for MNIST) and reshape the
        # 1D array ([784]) to a 2D array ([height, width], i.e. [28, 28])
        img_as_np = np.asarray(self.data.iloc[index][1:]).reshape(self.height, self.width).astype('uint8')
        # Convert image from numpy array to PIL image, mode 'L' is for grayscale
        img_as_img = Image.fromarray(img_as_np)
        img_as_img = img_as_img.convert('L')
        # Apply the transforms (which should include ToTensor) if any were given;
        # without them the PIL image is returned as is
        img_as_tensor = img_as_img
        if self.transforms is not None:
            img_as_tensor = self.transforms(img_as_img)
        # Return image and the label
        return (img_as_tensor, single_image_label)

    def __len__(self):
        return len(self.data.index)
        

if __name__ == "__main__":
    transformations = transforms.Compose([transforms.ToTensor()])
    custom_mnist_from_csv = \
        CustomDatasetFromCSV('../data/mnist_in_csv.csv', 28, 28, transformations)

4. Embedding Class names as File Names:

Here the class of each image is embedded in its file name, marked by `_c` followed by the class digit:

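import glob

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class CustomDatasetFromFile(Dataset):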
    def __init__(self, folder_path):
        """
        A dataset example where the class is embedded in the file names
        This data example also does not use any torch transforms

        Args:
            folder_path (string): path to image folder
        """
        # Get image list (folder_path must end with a path separator)
        self.image_list = glob.glob(folder_path+'*')
        # Calculate len
        self.data_len = len(self.image_list)

    def __getitem__(self, index):
        # Get image path from the image list
        single_image_path = self.image_list[index]
        # Open image
        im_as_im = Image.open(single_image_path)

        # Do some operations on image
        # Convert to numpy, dim = 28x28
        im_as_np = np.asarray(im_as_im)/255
        # Add channel dimension, dim = 1x28x28
        # Note: You do not need to do this if you are reading RGB images
        # or if there is already a channel dimension
        im_as_np = np.expand_dims(im_as_np, 0)
        # Some preprocessing operations on numpy array
        # ...
        # ...
        # ...

        # Transform image to tensor, change data type
        im_as_ten = torch.from_numpy(im_as_np).float()

        # Get label(class) of the image based on the file name
        class_indicator_location = single_image_path.rfind('_c')
        label = int(single_image_path[class_indicator_location+2:class_indicator_location+3])
        return (im_as_ten, label)

    def __len__(self):
        return self.data_len
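The slicing in __getitem__() above reads exactly one character after `_c`, so it only works for single-digit classes. As a hedged alternative, the sketch below uses a regular expression that also handles multi-digit classes. The helper label_from_path, the example file names, and the assumption that the class marker sits directly before the file extension are all hypothetical.

```python
import os
import re

# Matches "_c" followed by one or more digits right before the file extension
CLASS_PATTERN = re.compile(r"_c(\d+)\.[^.]+$")


def label_from_path(path):
    """Return the integer class encoded as '_c<digits>' at the end of a file name."""
    match = CLASS_PATTERN.search(os.path.basename(path))
    if match is None:
        raise ValueError(f"no class marker found in {path!r}")
    return int(match.group(1))


print(label_from_path("../data/img_0001_c7.png"))
print(label_from_path("../data/img_0002_c12.png"))
```

Raising an error for files without a marker makes naming mistakes visible early instead of producing a wrong label silently.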

5. Using DataLoader:

A PyTorch DataLoader calls __getitem__() repeatedly and wraps the results up into a batch. Technically, we could skip the DataLoader, call __getitem__() one sample at a time, and feed the data into the model ourselves, but the DataLoader handles batching and shuffling for us. We can create and use a DataLoader like this:

...
if __name__ == "__main__":
    # Define transforms
    transformations = transforms.Compose([transforms.ToTensor()])
    # Define custom dataset
    custom_mnist_from_csv = \
        CustomDatasetFromCSV('../data/mnist_in_csv.csv',
                             28, 28,
                             transformations)
    # Define data loader
    mn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv,
                                                    batch_size=10,
                                                    shuffle=False)
    
    for images, labels in mn_dataset_loader:
        # Feed the data to the model
        pass

Here, batch_size decides how many individual data points will be wrapped in a single batch. The DataLoader will return a Tensor of shape (Batch, Depth, Height, Width):

images.shape  # torch.Size([10, 1, 28, 28]) if batch_size=10 (for MNIST data)
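The batch shape can be checked without the MNIST CSV. This sketch uses PyTorch's built-in TensorDataset with 25 random tensors as a stand-in for MNIST images; the sample count of 25 is an arbitrary choice to show what happens when the dataset size is not a multiple of batch_size.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for MNIST: 25 random 1x28x28 images with labels in [0, 10)
fake_images = torch.rand(25, 1, 28, 28)
fake_labels = torch.randint(0, 10, (25,))
fake_dataset = TensorDataset(fake_images, fake_labels)

loader = DataLoader(fake_dataset, batch_size=10, shuffle=False)
# Collect the shape of the image tensor in every batch
batch_shapes = [tuple(images.shape) for images, labels in loader]
print(batch_shapes)
# → [(10, 1, 28, 28), (10, 1, 28, 28), (5, 1, 28, 28)]
```

The last batch holds only the 5 remaining samples; passing drop_last=True to the DataLoader would discard it instead.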

That’s it!!!

Custom Datasets!! No Worries!!
